A Useful PyTorch Lightning Feature: the on_save_checkpoint and on_load_checkpoint Hooks
A note on the hooks that let you store any picklable object in the checkpoint file during training.
on_save_checkpoint
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
With this hook, you can store information needed at inference time in the checkpoint alongside the model parameters, so you don't have to save it to a separate file and load it manually at inference time.
on_load_checkpoint
To restore the object when the checkpoint is loaded, implement the matching hook:
def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Use case
For example, you can save statistics computed from the training data, such as its covariance matrix, in the same file as the model weights.